#!/usr/bin/env python3



import argparse
import os

def numstr(n):
    """Return `n` formatted with `%s` and wrapped in double quotes."""
    quoted = '"%s"' % n
    return quoted

def trunc(n):
    """Truncate a Python value to a machine type.

    Booleans are returned unchanged.  Any other value is wrapped to a
    signed 64-bit integer: only the low 64 bits are kept and bit 63 is
    interpreted as the sign bit.
    """
    if isinstance(n, bool):
        return n

    # Keep the low 64 bits, then reinterpret the top bit as the sign.
    low = n & 0xFFFFFFFFFFFFFFFF
    if low & 0x8000000000000000:
        low -= 1 << 64
    return int(low)

def computeInterestingIntegers():
    """Build the sorted list of operand values exercised by the tests.

    Starts from 0, 63 and 1 << (1 << k) for k in 0..6, adds each value
    shifted one bit right and left, then every neighbour within +/-2,
    then the negation of everything.  Each result is wrapped with
    trunc() and duplicates are dropped.
    """
    seeds = {1 << (1 << k) for k in range(7)}
    seeds |= {0, 63}

    # Each seed plus its one-bit shifts in both directions.
    shifted = set()
    for value in seeds:
        shifted.update((value, value >> 1, value << 1))

    # Values just above and below each of those.
    nearby = {value + delta
              for value in shifted
              for delta in (-2, -1, 0, 1, 2)}

    # Mirror everything around zero.
    mirrored = nearby | {-value for value in nearby}

    return sorted({trunc(value) for value in mirrored})

# Operand values shared by every generated test: computed once at import
# time and iterated by the emit* functions and generateOp.
interesting = computeInterestingIntegers()

class Op(object):
    """A binary operation to test.

    name      -- short identifier used in generated function names.
    expr      -- %-format template with two %s slots giving the source
                 text of the operation applied to two operands.
    fn        -- Python callable (a, b) computing the reference result.
    isDivide  -- True when a zero right-hand operand must be reported as
                 DivisionByZero instead of being evaluated.
    isBool    -- True when the result is written with writeBool rather
                 than writeInt.
    """

    def __init__(self, name, expr, fn, isDivide = False, isBool = False):
        self.name, self.expr, self.fn = name, expr, fn
        self.isDivide = isDivide
        self.isBool = isBool

    def evaluate(self, a, b):
        """Return the expected output text for operands a and b.

        The reference result is passed through trunc() and lower-cased,
        so Python's True/False become "true"/"false".
        """
        result = trunc(self.fn(a, b))
        return str(result).lower()

    def printExpr(self, sk, a, aVar, b, bVar, indent):
        """Write to sk one statement that evaluates and prints the op.

        a and b are the operand source texts (a variable name or a
        constant).  aVar/bVar name the string variable holding that
        operand's printable form at runtime, or are None when the
        operand is a constant.  indent is the number of leading spaces.
        """
        pad = ' ' * indent
        callExpr = self.expr % (a, b)

        def describe(value, var):
            # A runtime operand is spliced into the description by closing
            # the string literal, concatenating the variable, and
            # reopening the literal; a constant is embedded as text.
            if var:
                return '" + %s + "' % var
            return str(value)

        # At runtime print out a description of the mathematical operation
        # being performed, to aid debugging.
        prefix = '"' + (self.expr % (describe(a, aVar),
                                     describe(b, bVar))) + ': "'
        emptyLead = '"" + '
        if prefix.startswith(emptyLead):
            # Optimize away concatenating an empty string.
            prefix = prefix[len(emptyLead):]

        if self.isDivide and not (isinstance(b, int) and b != 0):
            # Handle div/rem by zero: the divisor is either a runtime
            # value or the constant zero, so guard the call.
            guarded = (
                '%stry {' % pad,
                '%s  writeInt(%s, %s)' % (pad, callExpr, prefix),
                '%s} catch {' % pad,
                '%s  | DivisionByZeroException _ -> '
                'print_raw(%s + "DivisionByZero\\n")' % (pad, prefix),
                '%s};' % pad,
            )
            for line in guarded:
                print(line, file=sk)
            return

        if self.isBool:
            template = '%swriteBool(%s, %s);'
        else:
            template = '%swriteInt(%s, %s);'
        print(template % (pad, callExpr, prefix), file=sk)

    def noteExpected(self, expected, a, b):
        """Append to expected the output line for operands a and b."""
        label = self.expr % (a, b)
        if self.isDivide and b == 0:
            expected.append(label + ': DivisionByZero\n')
            return
        expected.append(label + ': ' + self.evaluate(a, b) + '\n')

def uns(x):
    """Treat a signed 64-bit number as unsigned.

    Masks x to its low 64 bits, so negative values map to their
    two's-complement unsigned equivalent.
    """
    mask = (1 << 64) - 1
    return x & mask

# Every supported operation, keyed by Op.name.  Each entry pairs the
# source-text template for the operation with a Python reference
# implementation; results are wrapped to 64 bits by Op.evaluate().
ops = {op.name: op for op in (
    Op('div', '%s / %s',
       # Python rounds integer division to minus infinity, but Skip
       # rounds towards zero, so divide the magnitudes and reapply the
       # sign.  This must be `//`: under Python 3 `/` is true division
       # and yields a float, which trunc() cannot mask and which could
       # not represent 64-bit quotients exactly anyway.
       lambda a, b: ((abs(a) // abs(b)) * (-1 if a * b < 0 else 1)) if b else 0,
       isDivide = True),
    Op('rem', '%s %% %s',
       # Python does modulo but we want remainder (ala C): the result
       # takes the sign of the dividend.
       lambda a, b: ((abs(a) % abs(b)) * (-1 if a < 0 else 1)) if b else 0,
       isDivide = True),
    Op('add', '%s + %s',  lambda a, b: a + b),
    Op('sub', '%s - %s',  lambda a, b: a - b),
    Op('mul', '%s * %s',  lambda a, b: a * b),
    Op('and', '(%s).and(%s)',  lambda a, b: a & b),
    Op('or', '(%s).or(%s)',  lambda a, b: a | b),
    Op('xor', '(%s).xor(%s)',  lambda a, b: a ^ b),
    # Shift counts use only their low six bits (b & 63).
    Op('shl', '(%s).shl(%s)', lambda a, b: a << (b & 63)),
    Op('ashr', '(%s).shr(%s)', lambda a, b: a >> (b & 63)),
    # Logical shift right: view the operand as unsigned before shifting.
    Op('lshr', '(%s).ushr(%s)', lambda a, b: uns(a) >> (b & 63)),
    Op('eq', '%s == %s', lambda a, b: a == b, isBool = True),
    Op('ne', '%s != %s', lambda a, b: a != b, isBool = True),
    Op('lt', '%s < %s',  lambda a, b: a <  b, isBool = True),
    Op('gt', '%s > %s',  lambda a, b: a >  b, isBool = True),
    Op('ge', '%s >= %s', lambda a, b: a >= b, isBool = True),
    Op('le', '%s <= %s', lambda a, b: a <= b, isBool = True),
    # Unsigned comparisons compare the 64-bit unsigned views.
    Op('ule', '(%s).ule(%s)', lambda a, b: uns(a) <= uns(b), isBool = True),
    Op('ult', '(%s).ult(%s)', lambda a, b: uns(a) <  uns(b), isBool = True),
    Op('ugt', '(%s).ugt(%s)', lambda a, b: uns(a) >  uns(b), isBool = True),
    Op('uge', '(%s).uge(%s)', lambda a, b: uns(a) >= uns(b), isBool = True),
)}


def emitVarVar(sk, op, calls, expected):
    """Emit the test function in which both operands are runtime values.

    Writes a function named do_<op>_var_var to sk that applies op to
    every ordered pair of test cases, appends that name to calls, and
    appends the matching expected output lines to expected.
    """
    name = 'do_%s_var_var' % op.name
    calls.append(name)
    expected.append('# Starting %s\n' % name)

    prologue = r"""
@no_inline
fun %s(testCases: Array<(Int, String)>): void {
  print_raw("# Starting %s\n");

  testCases.each(ap -> {
    (a, aStr) = ap;
    testCases.each(bp -> {
      (b, bStr) = bp;
    """ % (name, name)
    epilogue = """
      void
    })
  })
}
"""

    print(prologue, file=sk)
    op.printExpr(sk, 'a', 'aStr', 'b', 'bStr', 6)
    print(epilogue, file=sk)

    # Expected output follows the generated loops: a outer, b inner.
    for lhs in interesting:
        for rhs in interesting:
            op.noteExpected(expected, lhs, rhs)


def emitConstVar(sk, op, calls, expected):
    """Emit the test function with a constant left operand.

    Writes a function named do_<op>_const_var to sk that, for each
    runtime right operand, applies op with every interesting value
    spelled as a literal on the left.  Appends the function name to
    calls and the matching expected output lines to expected.
    """
    name = 'do_%s_const_var' % op.name
    calls.append(name)
    expected.append('# Starting %s\n' % name)

    prologue = r"""
@no_inline
fun %s(testCases: Array<(Int, String)>): void {
  print_raw("# Starting %s\n");

  testCases.each(bp -> {
    (b, bStr) = bp;
    """ % (name, name)
    epilogue = """
    void
  })
}
"""

    print(prologue, file=sk)
    for lhs in interesting:
        op.printExpr(sk, lhs, None, 'b', 'bStr', 4)
    print(epilogue, file=sk)

    # NOTE: Output order is different here as "a" varies in the inner loop.
    for rhs in interesting:
        for lhs in interesting:
            op.noteExpected(expected, lhs, rhs)


def emitVarConst(sk, op, calls, expected):
    """Emit the test function with a constant right operand.

    Writes a function named do_<op>_var_const to sk that, for each
    runtime left operand, applies op with every interesting value
    spelled as a literal on the right.  Appends the function name to
    calls and the matching expected output lines to expected.
    """
    name = 'do_%s_var_const' % op.name
    calls.append(name)
    expected.append('# Starting %s\n' % name)

    prologue = r"""
@no_inline
fun %s(testCases: Array<(Int, String)>): void {
  print_raw("# Starting %s\n");

  testCases.each(ap -> {
    (a, aStr) = ap;
    """ % (name, name)
    epilogue = """
    void
  })
}
"""

    print(prologue, file=sk)
    for rhs in interesting:
        op.printExpr(sk, 'a', 'aStr', rhs, None, 4)
    print(epilogue, file=sk)

    # Expected output follows the generated code: a outer, b inner.
    for lhs in interesting:
        for rhs in interesting:
            op.noteExpected(expected, lhs, rhs)


def emitConstConst(sk, op, calls, expected):
    """Emit the test function in which both operands are constants.

    Writes a function named do_<op>_const_const to sk containing one
    statement per ordered pair of interesting values, both spelled as
    literals.  Appends the function name to calls and the matching
    expected output lines to expected.
    """
    name = 'do_%s_const_const' % op.name
    calls.append(name)
    expected.append('# Starting %s\n' % name)

    prologue = r"""
@no_inline
fun %s(_testCases: Array<(Int, String)>): void {
  print_raw("# Starting %s\n");
""" % (name, name)
    epilogue = """
  void
}
"""

    print(prologue, file=sk)
    # The statements are emitted in the order they run, so the source
    # line and its expected output can be produced together.
    for lhs in interesting:
        for rhs in interesting:
            op.printExpr(sk, lhs, None, rhs, None, 2)
            op.noteExpected(expected, lhs, rhs)
    print(epilogue, file=sk)


def generateOp(sk, exp, op):
    """Generate the complete test program and expected output for one op.

    sk  -- writable text file receiving the generated test source.
    exp -- writable text file receiving the expected output lines.
    op  -- name of the operation; looked up in the ops table.
    """
    helpers = r"""
@no_inline
fun writeInt(n: Int, prefix: String): void {
  print_raw(prefix + n.toString() + "\n")
}

@no_inline
fun writeBool(n: Bool, prefix: String): void {
  print_raw(prefix + if (n) { "true\n" } else { "false\n" })
}

"""
    print(helpers, file=sk)

    # Create all the test functions, collecting their names and the
    # output each is expected to produce.
    calls, expected = [], []
    operation = ops[op]

    # Besides the all-runtime form, emit expressions where one or both
    # inputs are compile-time constants, to test the compiler's constant
    # folding and peepholing.  For example, 1 + 2 --> 3, x + 0 -> x, etc.
    emitters = (emitVarVar, emitConstVar, emitVarConst, emitConstConst)
    for emit in emitters:
        emit(sk, operation, calls, expected)

    separator = ',\n    '
    caseList = separator.join('(%d, "%s")' % (n, n) for n in interesting)
    callList = separator.join(calls)

    print(r"""
fun main(): void {
  testCases = Array[
    %s
  ];

  // NOTE: This "mutable" is to silence a bogus compiler error.
  calls = mutable Array[
    %s
  ];

  calls.each(call -> call(testCases));
}

""" % (caseList, callList), file=sk)

    exp.writelines(expected)


def main(argv=None):
    """Command-line entry point.

    With --list-ops, prints the known operation names joined by ';' and
    returns.  Otherwise writes inttests-<op>.sk and inttests-<op>.exp
    into --output-dir for the operation named by --op.

    argv -- argument list to parse; None means use sys.argv[1:].

    Invalid or missing arguments are reported through parser.error(),
    which prints a usage message and exits with status 2.
    """
    parser = argparse.ArgumentParser(description='Generate int tests')
    parser.add_argument('--list-ops', action='store_true')
    parser.add_argument('--output-dir')
    parser.add_argument('--op')
    args = parser.parse_args(argv)

    if args.list_ops:
        print(';'.join(sorted(ops)))
        return

    # Validate explicitly rather than with assert, which is stripped when
    # Python runs with -O.
    if not args.output_dir:
        parser.error('--output-dir is required')
    if not args.op:
        parser.error('--op is required')
    # Reject unknown names before opening (and truncating) any output file.
    if args.op not in ops:
        parser.error('unknown --op %r; use --list-ops to see valid names'
                     % args.op)

    name = os.path.join(args.output_dir, 'inttests-' + args.op)
    with open(name + '.sk', 'w') as sk, \
         open(name + '.exp', 'w') as exp:
        generateOp(sk, exp, args.op)


# Only run the command-line interface when executed as a script, so the
# module can be imported without parsing sys.argv or writing files.
if __name__ == '__main__':
    main()
